# -*- coding: utf-8 -*-
import time
import csv
import os
import argparse
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

from backend import *
from datasets import *
from utils import *
# Seed PyTorch's default random number generator with a fixed value at import time.
torch.manual_seed(0)

def ParseArgs():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``lmbda`` (float), ``max_epoch`` (int),
        ``backend`` (str) and ``dataset_name`` (str).
    """
    # (flag, keyword arguments forwarded to ``add_argument``)
    option_specs = (
        ('--lmbda', dict(default=1e-3, type=float, help='weighting parameters')),
        ('--max_epoch', dict(default=300, type=int)),
        ('--backend', dict(default='vgg16', type=str)),  # vgg16 | resnet18
        ('--dataset_name', dict(default='cifar10', type=str)),  # cifar10 | mnist
    )
    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

class HSPG(Optimizer):
    """Optimizer offering a proximal-gradient step and a half-space step.

    Both steps apply a group-l2 regulariser (weight ``lmbda``) to every
    parameter for which ``is_conv_weights(p.shape)`` is true, treating each
    slice along the first dimension as one group.  All other parameters get a
    plain gradient-descent update.  ``is_conv_weights`` is defined outside
    this class (NOTE(review): presumably it accepts 4-D convolution weight
    shapes -- confirm against its definition).

    The class does not override ``step()``; callers invoke ``proxsg_step`` or
    ``half_space_step`` explicitly.

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate; must be non-negative.
        lmbda: regularisation weight; must be non-negative.
        epsilon: threshold used by the half-space projection; must be
            non-negative.

    Raises:
        ValueError: if ``lr``, ``lmbda`` or ``epsilon`` is negative.
    """

    def __init__(self, params, lr=required, lmbda = required, epsilon=required):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))

        if lmbda is not required and lmbda < 0.0:
            raise ValueError("Invalid lambda: {}".format(lmbda))

        if epsilon is not required and epsilon < 0.0:
            raise ValueError("Invalid epsilon: {}".format(epsilon))

        defaults = dict(lr=lr, lmbda=lmbda, epsilon=epsilon)
        super(HSPG, self).__init__(params, defaults)
        self.name = 'HSPG'

    def __setstate__(self, state):
        super(HSPG, self).__setstate__(state)

    @staticmethod
    def _broadcast_per_group(values, like):
        """Reshape a ``(num_groups,)`` vector so it broadcasts over ``like``."""
        return values.view(-1, *([1] * (like.dim() - 1)))

    def prox_mapping_group_conv(self, x, grad_f, lmbda, lr):
        """Return the proximal-gradient displacement for Omega(x) = sum_g ||[x]_g||_2.

        Each slice ``x[i]`` is one group.  The gradient trial point
        ``x - lr * grad_f`` is shrunk group-wise by
        ``max(0, 1 - lr * lmbda / ||trial_g||)``.

        Args:
            x: current iterate; the first dimension enumerates the groups.
            grad_f: gradient of the smooth loss, same shape as ``x``.
            lmbda: regularisation weight.
            lr: step size.

        Returns:
            Tensor shaped like ``x`` holding ``next_x - x``.
        """
        trial_x = x - lr * grad_f
        numer = lr * lmbda
        denoms = torch.norm(trial_x.view(x.shape[0], -1), p=2, dim=1)
        # The 1e-6 keeps the quotient finite when a whole group is zero.
        coeffs = 1.0 - numer / (denoms + 1e-6)
        coeffs[coeffs <= 0] = 0.0
        return self._broadcast_per_group(coeffs, x) * trial_x - x

    def proxsg_step(self, closure=None):
        """Perform one proximal stochastic-gradient update on all parameters.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lmbda = group['lmbda']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data
                if is_conv_weights(p.shape):
                    p.data.add_(self.prox_mapping_group_conv(p.data, grad_f, lmbda, lr))
                else:
                    # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor) overload.
                    p.data.add_(grad_f, alpha=-lr)
        return loss

    def half_space_step(self, closure=None):
        """Perform one half-space update on all parameters.

        For regularised parameters a gradient step on the regularised
        objective is taken.  A group ``g`` is then set to zero when
        ``<hat_x_g, x_g> < epsilon * ||x_g||^2``.  Groups that were exactly
        zero before the step are kept at zero.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lmbda = group['lmbda']
            epsilon = group['epsilon']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad_f = p.grad.data

                if is_conv_weights(p.shape):
                    num_groups = p.shape[0]
                    norm_x = torch.norm(p.data.view(num_groups, -1), p=2, dim=1)
                    zero_idx = norm_x == 0

                    hat_x = self.gradient_step_conv(p.data, grad_f, lr, lmbda)

                    # Half-space projection: per-group inner product with the old iterate.
                    inner = torch.bmm(
                        hat_x.view(num_groups, 1, -1),
                        p.data.view(num_groups, -1, 1),
                    ).view(-1)
                    hat_x[inner < epsilon * norm_x ** 2, ...] = 0.0

                    # Groups that are already zero are not free variables.
                    hat_x[zero_idx, ...] = 0.0

                    p.data.copy_(hat_x)
                else:
                    # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor) overload.
                    p.data.add_(grad_f, alpha=-lr)

        return loss

    def gradient_step_conv(self, x, grad_f, lr, lmbda):
        """Return ``x - lr * (grad_f + lmbda * x_g / ||x_g||)`` computed group-wise.

        Each slice ``x[i]`` is one group.  For a group whose norm is exactly
        zero the regulariser term is taken as zero instead of dividing by
        zero, so the result never contains NaN from that division.

        Args:
            x: current iterate; the first dimension enumerates the groups.
            grad_f: gradient of the smooth loss, same shape as ``x``.
            lr: step size.
            lmbda: regularisation weight.

        Returns:
            New tensor shaped like ``x``.
        """
        norms = torch.norm(x.view(x.shape[0], -1), p=2, dim=1)
        # A zero group has x == 0, so dividing by 1 yields the desired zero term.
        safe_norms = torch.where(norms > 0, norms, torch.ones_like(norms))
        return x - lr * grad_f - lr * lmbda * x / self._broadcast_per_group(safe_norms, x)

    def gradient_step_uniform(self, x, grad_f, lr, lmbda, group_indexes):
        """Gradient step where groups are sets of columns of a 2-D ``x``.

        Args:
            x: current iterate, indexed as ``x[:, group_index]``.
            grad_f: gradient of the smooth loss, same shape as ``x``.
            lr: step size.
            lmbda: regularisation weight.
            group_indexes: iterable of column indexers, one per group.

        Returns:
            New tensor shaped like ``x``.  Groups with zero norm only receive
            the plain gradient step (no division by zero).
        """
        trial_x = x - lr * grad_f
        for group_index in group_indexes:
            group_l2_norm = torch.norm(x[:, group_index], p=2)
            if group_l2_norm > 0:
                trial_x[:, group_index] -= lr * lmbda * x[:, group_index] / group_l2_norm
        return trial_x

    def adjust_epsilon(self, step=0.05):
        """Multiply every group's ``epsilon`` by ``1 + step`` and print the change."""
        print("epsilon update: {}".format(self.param_groups[0]['epsilon']), end=' --> ')
        for group in self.param_groups:
            group['epsilon'] = group['epsilon'] * (1 + step)
        print(self.param_groups[0]['epsilon'])

    def adjust_learning_rate(self, epoch):
        """Divide every group's learning rate by 10 at each positive multiple of 75."""
        if epoch % 75 == 0 and epoch > 0:
            for group in self.param_groups:
                group['lr'] /= float(10)


if __name__ == "__main__":
    # Training driver: runs HSPG on the chosen backend/dataset and logs one
    # CSV row per epoch under ``results/``.

    args = ParseArgs()
    lmbda = args.lmbda
    max_epoch = args.max_epoch
    backend = args.backend
    dataset_name = args.dataset_name
    alpha = 1e-1  # initial learning rate handed to the optimizer
    switch_epoch = 150  # epochs before this use proxsg_step, later ones half_space_step
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainloader, testloader = Dataset(dataset_name)
    model = Model(backend, device)

    # Parameters whose name contains "weight"; passed to the metric helpers below.
    weights = [w for name, w in model.named_parameters() if "weight" in name]

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = HSPG(model.parameters(), lr=alpha, lmbda=lmbda, epsilon=0.0)

    os.makedirs('results', exist_ok=True)
    # torch.save does not create missing directories, so make sure this one exists.
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint_path = os.path.join("checkpoints", backend + '_' + dataset_name + '.pt')

    csvname = 'results/hspg_%s_%s_%E.csv'%(backend, dataset_name, lmbda)
    print('The csv file is %s'%csvname)

    fieldnames = ['epoch', 'F_value', 'f_value', 'omega_value', 'sparsity', 'sparsity_tol', 'sparsity_group', 'validation_acc', 'train_time', 'remarks']
    remarks = '%s;%s;%E'%(backend, dataset_name, lmbda)

    # The context manager closes the CSV file even if training raises.
    with open(csvname, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=",")
        writer.writeheader()

        alg_start_time = time.time()
        for epoch in range(max_epoch):
            epoch_start_time = time.time()
            use_proxsg = epoch < switch_epoch

            for X, y in trainloader:
                X = X.to(device)
                y = y.to(device)
                y_pred = model.forward(X)

                loss = criterion(y_pred, y)
                optimizer.zero_grad()
                loss.backward()
                if use_proxsg:
                    optimizer.proxsg_step()
                else:
                    optimizer.half_space_step()

            # Checkpoint once per epoch (after the first epoch and after the
            # last proxsg epoch) rather than after every batch.
            if epoch == 0 or epoch == switch_epoch - 1:
                torch.save(model.state_dict(), checkpoint_path)

            # Logged epochs and the learning-rate schedule are 1-based.
            finished_epochs = epoch + 1
            optimizer.adjust_learning_rate(finished_epochs)

            train_time = time.time() - epoch_start_time
            F, f, omega = compute_func_values(trainloader, model, weights, criterion, lmbda)
            sparsity, sparsity_tol, sparsity_group, _ = compute_sparsity(weights)
            accuracy = check_accuracy(model, testloader)

            writer.writerow({'epoch': finished_epochs, 'F_value': F, 'f_value': f, 'omega_value': omega, 'sparsity': sparsity, 'sparsity_tol': sparsity_tol, 'sparsity_group': sparsity_group, 'validation_acc': accuracy, 'train_time': train_time, 'remarks': remarks})

            csvfile.flush()
            print("epoch {}: {:.2f}seconds ...".format(finished_epochs, train_time))

        alg_time = time.time() - alg_start_time
        # Final row: average wall-clock time per epoch; skipped when nothing ran
        # so that max_epoch == 0 does not divide by zero.
        if max_epoch > 0:
            writer.writerow({'train_time': alg_time / max_epoch})

